{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Deep Q-Network implementation\n",
    "\n",
    "This homework shamelessly demands that you implement a DQN, an approximate Q-learning algorithm with experience replay and target networks, and see whether it works any better this way.\n",
    "\n",
    "Original paper:\n",
    "https://arxiv.org/pdf/1312.5602.pdf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**This notebook is for debugging.** The main task is in the other notebook (**homework_pytorch_main**). The two tasks are similar and share most of the code; the main difference is the environment. In the main notebook it can take about 2 hours for the agent to start improving, so it is reasonable to run the algorithm on a simpler environment first. Here the environment is CartPole, and training takes several minutes.\n",
    "\n",
    "**We suggest the following pipeline:** first implement the debug notebook, then the main one.\n",
    "\n",
    "**About evaluation:** All points are given for the main notebook, with one exception: if the agent fails to beat the threshold in the main notebook, you can get 1 pt (instead of 3 pts) for beating the threshold in the debug notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # in google colab uncomment this\n",
    "\n",
    "# import os\n",
    "\n",
    "# os.system('apt-get install -y xvfb')\n",
    "# os.system('wget https://raw.githubusercontent.com/yandexdataschool/Practical_DL/fall18/xvfb -O ../xvfb')\n",
    "# os.system('apt-get install -y python-opengl ffmpeg')\n",
    "# os.system('pip install pyglet==1.2.4')\n",
    "\n",
    "# os.system('python -m pip install -U pygame --user')\n",
    "\n",
    "# prefix = 'https://raw.githubusercontent.com/yandexdataschool/Practical_RL/master/week04_approx_rl/'\n",
    "\n",
    "# os.system('wget ' + prefix + 'atari_wrappers.py')\n",
    "# os.system('wget ' + prefix + 'utils.py')\n",
    "# os.system('wget ' + prefix + 'replay_buffer.py')\n",
    "# os.system('wget ' + prefix + 'framebuffer.py')\n",
    "\n",
    "# print('setup complete')\n",
    "\n",
    "# XVFB will be launched if you run on a server\n",
    "import os\n",
    "if type(os.environ.get(\"DISPLAY\")) is not str or len(os.environ.get(\"DISPLAY\")) == 0:\n",
    "    !bash ../xvfb start\n",
    "    os.environ['DISPLAY'] = ':1'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "__Frameworks__ - we'll accept this homework in any deep learning framework. This particular notebook was designed for PyTorch, but you should find it easy to adapt it to almost any Python-based deep learning framework."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import numpy as np\n",
    "import torch\n",
    "import utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gym\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CartPole again"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ENV_NAME = 'CartPole-v1'\n",
    "\n",
    "def make_env(seed=None):\n",
    "    # CartPole is wrapped with a time limit wrapper by default;\n",
    "    # .unwrapped removes it, so episode length is bounded only by t_max in evaluate\n",
    "    env = gym.make(ENV_NAME).unwrapped\n",
    "    if seed is not None:\n",
    "        env.seed(seed)\n",
    "    return env"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "env = make_env()\n",
    "env.reset()\n",
    "state_shape, n_actions = env.observation_space.shape, env.action_space.n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Building a network"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now need to build a neural network that can map observations to state q-values.\n",
    "The model does not have to be huge yet. 1-2 hidden layers with < 200 neurons and ReLU activation will probably be enough. Batch normalization and dropout can spoil everything here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "# those who have a GPU but feel it is unfair to use it can uncomment:\n",
    "# device = torch.device('cpu')\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DQNAgent(nn.Module):\n",
    "    def __init__(self, state_shape, n_actions, epsilon=0):\n",
    "\n",
    "        super().__init__()\n",
    "        self.epsilon = epsilon\n",
    "        self.n_actions = n_actions\n",
    "        self.state_shape = state_shape\n",
    "        # Define your network body here. Please make sure agent is fully contained here\n",
    "        assert len(state_shape) == 1\n",
    "        state_dim = state_shape[0]\n",
    "        <YOUR CODE>\n",
    "\n",
    "        \n",
    "    def forward(self, state_t):\n",
    "        \"\"\"\n",
    "        takes agent's observation (tensor), returns qvalues (tensor)\n",
    "        :param state_t: a batch of states, shape = [batch_size, state_dim] (state_dim = 4 for CartPole)\n",
    "        \"\"\"\n",
    "        # Use your network to compute qvalues for given state\n",
    "        qvalues = <YOUR CODE>\n",
    "\n",
    "        assert qvalues.requires_grad, \"qvalues must be a torch tensor with grad\"\n",
    "        assert len(\n",
    "            qvalues.shape) == 2 and qvalues.shape[0] == state_t.shape[0] and qvalues.shape[1] == self.n_actions\n",
    "\n",
    "        return qvalues\n",
    "\n",
    "    def get_qvalues(self, states):\n",
    "        \"\"\"\n",
    "        like forward, but works on numpy arrays, not tensors\n",
    "        \"\"\"\n",
    "        model_device = next(self.parameters()).device\n",
    "        states = torch.tensor(states, device=model_device, dtype=torch.float32)\n",
    "        qvalues = self.forward(states)\n",
    "        return qvalues.data.cpu().numpy()\n",
    "\n",
    "    def sample_actions(self, qvalues):\n",
    "        \"\"\"pick actions given qvalues. Uses epsilon-greedy exploration strategy. \"\"\"\n",
    "        epsilon = self.epsilon\n",
    "        batch_size, n_actions = qvalues.shape\n",
    "\n",
    "        random_actions = np.random.choice(n_actions, size=batch_size)\n",
    "        best_actions = qvalues.argmax(axis=-1)\n",
    "\n",
    "        should_explore = np.random.choice(\n",
    "            [0, 1], batch_size, p=[1-epsilon, epsilon])\n",
    "        return np.where(should_explore, random_actions, best_actions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "agent = DQNAgent(state_shape, n_actions, epsilon=0.5).to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's try out our agent to see if it raises any errors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(env, agent, n_games=1, greedy=False, t_max=10000):\n",
    "    \"\"\" Plays n_games full games. If greedy, picks actions as argmax(qvalues). Returns mean reward. \"\"\"\n",
    "    rewards = []\n",
    "    for _ in range(n_games):\n",
    "        s = env.reset()\n",
    "        reward = 0\n",
    "        for _ in range(t_max):\n",
    "            qvalues = agent.get_qvalues([s])\n",
    "            action = qvalues.argmax(axis=-1)[0] if greedy else agent.sample_actions(qvalues)[0]\n",
    "            s, r, done, _ = env.step(action)\n",
    "            reward += r\n",
    "            if done:\n",
    "                break\n",
    "\n",
    "        rewards.append(reward)\n",
    "    return np.mean(rewards)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "evaluate(env, agent, n_games=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Experience replay\n",
    "For this assignment, we provide you with an experience replay buffer. If you implemented an experience replay buffer in last week's assignment, you can copy-paste it into the main notebook **to get 2 bonus points**.\n",
    "\n",
    "![img](https://github.com/yandexdataschool/Practical_RL/raw/master/yet_another_week/_resource/exp_replay.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### The interface is fairly simple:\n",
    "* `exp_replay.add(obs, act, rw, next_obs, done)` - saves (s,a,r,s',done) tuple into the buffer\n",
    "* `exp_replay.sample(batch_size)` - returns observations, actions, rewards, next_observations and is_done for `batch_size` random samples.\n",
    "* `len(exp_replay)` - returns number of elements stored in replay buffer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from replay_buffer import ReplayBuffer\n",
    "exp_replay = ReplayBuffer(10)\n",
    "\n",
    "for _ in range(30):\n",
    "    exp_replay.add(env.reset(), env.action_space.sample(),\n",
    "                   1.0, env.reset(), done=False)\n",
    "\n",
    "obs_batch, act_batch, reward_batch, next_obs_batch, is_done_batch = exp_replay.sample(\n",
    "    5)\n",
    "\n",
    "assert len(exp_replay) == 10, \"experience replay size should be 10 because that's what maximum capacity is\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def play_and_record(initial_state, agent, env, exp_replay, n_steps=1):\n",
    "    \"\"\"\n",
    "    Play the game for exactly n steps, record every (s,a,r,s', done) to replay buffer. \n",
    "    Whenever game ends, add record with done=True and reset the game.\n",
    "    It is guaranteed that env has done=False when passed to this function.\n",
    "\n",
    "    PLEASE DO NOT RESET ENV UNLESS IT IS \"DONE\"\n",
    "\n",
    "    :returns: return sum of rewards over time and the state in which the env stays\n",
    "    \"\"\"\n",
    "    s = initial_state\n",
    "    sum_rewards = 0\n",
    "\n",
    "    # Play the game for n_steps as per instructions above\n",
    "    <YOUR CODE >\n",
    "\n",
    "    return sum_rewards, s"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# testing your code.\n",
    "exp_replay = ReplayBuffer(2000)\n",
    "\n",
    "state = env.reset()\n",
    "play_and_record(state, agent, env, exp_replay, n_steps=1000)\n",
    "\n",
    "# if you're using your own experience replay buffer, some of those tests may need correction.\n",
    "# just make sure you know what your code does\n",
    "assert len(exp_replay) == 1000, \"play_and_record should have added exactly 1000 steps, \"\\\n",
    "                                 \"but instead added %i\" % len(exp_replay)\n",
    "is_dones = list(zip(*exp_replay._storage))[-1]\n",
    "\n",
    "assert 0 < np.mean(is_dones) < 0.1, \"Please make sure you restart the game whenever it is 'done' and record the is_done correctly into the buffer. \"\\\n",
    "                                    \"Got %f is_done rate over %i steps. [If you think it's your tough luck, just re-run the test]\" % (\n",
    "                                        np.mean(is_dones), len(exp_replay))\n",
    "\n",
    "for _ in range(100):\n",
    "    obs_batch, act_batch, reward_batch, next_obs_batch, is_done_batch = exp_replay.sample(\n",
    "        10)\n",
    "    assert obs_batch.shape == next_obs_batch.shape == (10,) + state_shape\n",
    "    assert act_batch.shape == (\n",
    "        10,), \"actions batch should have shape (10,) but is instead %s\" % str(act_batch.shape)\n",
    "    assert reward_batch.shape == (\n",
    "        10,), \"rewards batch should have shape (10,) but is instead %s\" % str(reward_batch.shape)\n",
    "    assert is_done_batch.shape == (\n",
    "        10,), \"is_done batch should have shape (10,) but is instead %s\" % str(is_done_batch.shape)\n",
    "    # all(...) is required: asserting a non-empty list is always true\n",
    "    assert all(int(i) in (0, 1)\n",
    "               for i in is_dones), \"is_done should be strictly True or False\"\n",
    "    assert all(\n",
    "        0 <= a < n_actions for a in act_batch), \"actions should be within [0, n_actions)\"\n",
    "\n",
    "print(\"Well done!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Target networks\n",
    "\n",
    "We also employ the so-called \"target network\", a copy of the neural network weights used to compute reference Q-values.\n",
    "\n",
    "The network itself is an exact copy of the agent network, but its parameters are not trained. Instead, they are copied over from the agent's actual network every so often.\n",
    "\n",
    "$$ Q_{reference}(s,a) = r + \\gamma \\cdot \\max _{a'} Q_{target}(s',a') $$\n",
    "\n",
    "![img](https://github.com/yandexdataschool/Practical_RL/raw/master/yet_another_week/_resource/target_net.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "target_network = DQNAgent(agent.state_shape, agent.n_actions, epsilon=0.5).to(device)\n",
    "# This is how you can load weights from agent into target network\n",
    "target_network.load_state_dict(agent.state_dict())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Learning with... Q-learning\n",
    "Here we write a function similar to `agent.update` from tabular q-learning."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compute the Q-learning TD loss:\n",
    "\n",
    "$$ L = { 1 \\over N} \\sum_i [ Q_{\\theta}(s,a) - Q_{reference}(s,a) ] ^2 $$\n",
    "\n",
    "With Q-reference defined as\n",
    "\n",
    "$$ Q_{reference}(s,a) = r(s,a) + \\gamma \\cdot \\max_{a'} Q_{target}(s', a') $$\n",
    "\n",
    "Where\n",
    "* $Q_{target}(s',a')$ denotes the q-value of the next state and next action predicted by __target_network__\n",
    "* $s, a, r, s'$ are the current state, action, reward and next state respectively\n",
    "* $\\gamma$ is the discount factor, passed to `compute_td_loss` as `gamma` (0.99 by default).\n",
    "\n",
    "\n",
    "__Note 1:__ there's an example input below. Feel free to experiment with it before you write the function.\n",
    "\n",
    "__Note 2:__ compute_td_loss is a source of 99% of bugs in this homework. If reward doesn't improve, it often helps to go through it line by line [with a rubber duck](https://rubberduckdebugging.com/)."
   ]
  },
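  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of the reference formula, here is a hand-checkable sketch for a single transition in plain Python. The numbers and the helper name `q_reference` are made up for this example and are not part of the assignment code; `compute_td_loss` is expected to do the equivalent on batched torch tensors.\n",
    "\n",
    "```python\n",
    "# hypothetical helper, for illustration only\n",
    "def q_reference(reward, next_qvalues, is_done, gamma=0.99):\n",
    "    # V(s') is the best next-state q-value according to the target network\n",
    "    next_state_value = max(next_qvalues)\n",
    "    # terminal transitions have no next state, so only the reward remains\n",
    "    return reward + gamma * next_state_value * (1 - int(is_done))\n",
    "\n",
    "# made-up target-network outputs for the two CartPole actions\n",
    "print(q_reference(1.0, [2.0, 3.0], is_done=False))\n",
    "print(q_reference(1.0, [2.0, 3.0], is_done=True))\n",
    "```\n",
    "\n",
    "The first call should give roughly 1 + 0.99 * 3 = 3.97, and the second just the reward, 1.0."
   ]
  },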
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_td_loss(states, actions, rewards, next_states, is_done,\n",
    "                    agent, target_network,\n",
    "                    gamma=0.99,\n",
    "                    check_shapes=False,\n",
    "                    device=device):\n",
    "    \"\"\" Compute td loss using torch operations only. Use the formulae above. \"\"\"\n",
    "    states = torch.tensor(states, device=device, dtype=torch.float)    # shape: [batch_size, *state_shape]\n",
    "\n",
    "    # actions must be a long (integer) tensor because they are used as indices below\n",
    "    actions = torch.tensor(actions, device=device, dtype=torch.long)    # shape: [batch_size]\n",
    "    rewards = torch.tensor(rewards, device=device, dtype=torch.float)  # shape: [batch_size]\n",
    "    # shape: [batch_size, *state_shape]\n",
    "    next_states = torch.tensor(next_states, device=device, dtype=torch.float)\n",
    "    is_done = torch.tensor(\n",
    "        is_done.astype('float32'),\n",
    "        device=device,\n",
    "        dtype=torch.float\n",
    "    )  # shape: [batch_size]\n",
    "    is_not_done = 1 - is_done\n",
    "\n",
    "    # get q-values for all actions in current states\n",
    "    predicted_qvalues = agent(states)\n",
    "\n",
    "    # compute q-values for all actions in next states\n",
    "    predicted_next_qvalues = target_network(next_states)\n",
    "    \n",
    "    # select q-values for chosen actions\n",
    "    predicted_qvalues_for_actions = predicted_qvalues[range(\n",
    "        len(actions)), actions]\n",
    "\n",
    "    # compute V*(next_states) using predicted next q-values\n",
    "    next_state_values = <YOUR CODE>\n",
    "\n",
    "    assert next_state_values.dim(\n",
    "    ) == 1 and next_state_values.shape[0] == states.shape[0], \"must predict one value per state\"\n",
    "\n",
    "    # compute \"target q-values\" for loss - this is Q_reference in the formula above.\n",
    "    # at the last state use the simplified formula: Q(s,a) = r(s,a) since s' doesn't exist\n",
    "    # you can multiply next state values by is_not_done to achieve this.\n",
    "    target_qvalues_for_actions = <YOUR CODE>\n",
    "\n",
    "    # mean squared error loss to minimize\n",
    "    loss = torch.mean((predicted_qvalues_for_actions -\n",
    "                       target_qvalues_for_actions.detach()) ** 2)\n",
    "\n",
    "    if check_shapes:\n",
    "        assert predicted_next_qvalues.data.dim(\n",
    "        ) == 2, \"make sure you predicted q-values for all actions in next state\"\n",
    "        assert next_state_values.data.dim(\n",
    "        ) == 1, \"make sure you computed V(s') as maximum over just the actions axis and not all axes\"\n",
    "        assert target_qvalues_for_actions.data.dim(\n",
    "        ) == 1, \"there's something wrong with target q-values, they must be a vector\"\n",
    "\n",
    "    return loss"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Sanity checks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "obs_batch, act_batch, reward_batch, next_obs_batch, is_done_batch = exp_replay.sample(\n",
    "    10)\n",
    "\n",
    "loss = compute_td_loss(obs_batch, act_batch, reward_batch, next_obs_batch, is_done_batch,\n",
    "                       agent, target_network,\n",
    "                       gamma=0.99, check_shapes=True)\n",
    "loss.backward()\n",
    "\n",
    "assert loss.requires_grad and tuple(loss.data.size()) == (\n",
    "    ), \"you must return scalar loss - mean over batch\"\n",
    "assert np.any(next(agent.parameters()).grad.data.cpu().numpy() !=\n",
    "              0), \"loss must be differentiable w.r.t. network weights\"\n",
    "assert np.all(next(target_network.parameters()).grad is None), \"target network should not have grads\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Main loop\n",
    "\n",
    "It's time to put everything together and see if it learns anything."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import trange\n",
    "from IPython.display import clear_output\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = <your favourite random seed>\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "env = make_env(seed)\n",
    "state_dim = env.observation_space.shape\n",
    "n_actions = env.action_space.n\n",
    "state = env.reset()\n",
    "\n",
    "agent = DQNAgent(state_dim, n_actions, epsilon=1).to(device)\n",
    "target_network = DQNAgent(state_dim, n_actions, epsilon=1).to(device)\n",
    "target_network.load_state_dict(agent.state_dict())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "exp_replay = ReplayBuffer(10**4)\n",
    "for i in range(100):\n",
    "    if not utils.is_enough_ram(min_available_gb=0.1):\n",
    "        print(\"\"\"\n",
    "            Less than 100 MB of RAM available.\n",
    "            Make sure the buffer size is not too large.\n",
    "            Also check whether other processes are consuming a lot of RAM.\n",
    "            \"\"\"\n",
    "             )\n",
    "        break\n",
    "    # keep the returned state so the next call continues from where the env actually is\n",
    "    _, state = play_and_record(state, agent, env, exp_replay, n_steps=10**2)\n",
    "    if len(exp_replay) == 10**4:\n",
    "        break\n",
    "print(len(exp_replay))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "timesteps_per_epoch = 1\n",
    "batch_size = 32\n",
    "total_steps = 4 * 10**4\n",
    "decay_steps = 1 * 10**4\n",
    "\n",
    "opt = torch.optim.Adam(agent.parameters(), lr=1e-4)\n",
    "\n",
    "init_epsilon = 1\n",
    "final_epsilon = 0.1\n",
    "\n",
    "loss_freq = 20\n",
    "refresh_target_network_freq = 100\n",
    "eval_freq = 1000\n",
    "\n",
    "max_grad_norm = 5000"
   ]
  },
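  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The training loop sets epsilon with `utils.linear_decay(init_epsilon, final_epsilon, step, decay_steps)`. The sketch below shows one way such a schedule could be written; it is an assumption about what that helper does rather than a copy of it, and `linear_decay_sketch` is a hypothetical name.\n",
    "\n",
    "```python\n",
    "# hypothetical stand-in for utils.linear_decay, for illustration only\n",
    "def linear_decay_sketch(init_val, final_val, cur_step, total_steps):\n",
    "    # after the decay period the value stays at final_val\n",
    "    if cur_step >= total_steps:\n",
    "        return final_val\n",
    "    # linear interpolation between init_val and final_val\n",
    "    fraction = cur_step / total_steps\n",
    "    return init_val + (final_val - init_val) * fraction\n",
    "\n",
    "for step in (0, 5000, 10000, 40000):\n",
    "    print(step, linear_decay_sketch(1, 0.1, step, 10**4))\n",
    "```\n",
    "\n",
    "Under this assumption and with the settings above, epsilon would start at 1, pass through about 0.55 at step 5000, and stay at 0.1 from step 10000 on."
   ]
  },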
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mean_rw_history = []\n",
    "td_loss_history = []\n",
    "grad_norm_history = []\n",
    "initial_state_v_history = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "state = env.reset()\n",
    "for step in trange(total_steps + 1):\n",
    "    if not utils.is_enough_ram():\n",
    "        print('less than 100 MB of RAM available, freezing')\n",
    "        print('make sure everything is ok and send KeyboardInterrupt to continue')\n",
    "        try:\n",
    "            while True:\n",
    "                pass\n",
    "        except KeyboardInterrupt:\n",
    "            pass\n",
    "\n",
    "    agent.epsilon = utils.linear_decay(init_epsilon, final_epsilon, step, decay_steps)\n",
    "\n",
    "    # play\n",
    "    _, state = play_and_record(state, agent, env, exp_replay, timesteps_per_epoch)\n",
    "\n",
    "    # train\n",
    "    < sample batch_size of data from experience replay >\n",
    "\n",
    "    loss = < compute TD loss >\n",
    "\n",
    "    loss.backward()\n",
    "    grad_norm = nn.utils.clip_grad_norm_(agent.parameters(), max_grad_norm)\n",
    "    opt.step()\n",
    "    opt.zero_grad()\n",
    "\n",
    "    if step % loss_freq == 0:\n",
    "        td_loss_history.append(loss.data.cpu().item())\n",
    "        # clip_grad_norm_ may return a tensor; store a plain float for plotting\n",
    "        grad_norm_history.append(float(grad_norm))\n",
    "\n",
    "    if step % refresh_target_network_freq == 0:\n",
    "        # Load agent weights into target_network\n",
    "        <YOUR CODE >\n",
    "\n",
    "    if step % eval_freq == 0:\n",
    "        # eval the agent\n",
    "        mean_rw_history.append(evaluate(\n",
    "            make_env(seed=step), agent, n_games=3, greedy=True, t_max=1000)\n",
    "        )\n",
    "        initial_state_q_values = agent.get_qvalues(\n",
    "            [make_env(seed=step).reset()]\n",
    "        )\n",
    "        initial_state_v_history.append(np.max(initial_state_q_values))\n",
    "\n",
    "        clear_output(True)\n",
    "        print(\"buffer size = %i, epsilon = %.5f\" %\n",
    "              (len(exp_replay), agent.epsilon))\n",
    "\n",
    "        plt.figure(figsize=[16, 9])\n",
    "        plt.subplot(2, 2, 1)\n",
    "        plt.title(\"Mean reward per episode\")\n",
    "        plt.plot(mean_rw_history)\n",
    "        plt.grid()\n",
    "\n",
    "        assert not np.isnan(td_loss_history[-1])\n",
    "        plt.subplot(2, 2, 2)\n",
    "        plt.title(\"TD loss history (smoothened)\")\n",
    "        plt.plot(utils.smoothen(td_loss_history))\n",
    "        plt.grid()\n",
    "\n",
    "        plt.subplot(2, 2, 3)\n",
    "        plt.title(\"Initial state V\")\n",
    "        plt.plot(initial_state_v_history)\n",
    "        plt.grid()\n",
    "\n",
    "        plt.subplot(2, 2, 4)\n",
    "        plt.title(\"Grad norm history (smoothened)\")\n",
    "        plt.plot(utils.smoothen(grad_norm_history))\n",
    "        plt.grid()\n",
    "\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "final_score = evaluate(\n",
    "  make_env(),\n",
    "  agent, n_games=30, greedy=True, t_max=1000\n",
    ")\n",
    "print('final score:', final_score)\n",
    "assert final_score > 300, 'not good enough for DQN'\n",
    "print('Well done')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
